#!/usr/bin/env python3
# H22 — Transitions → Lines (diagnostic)
# CONTROL: present-act, boolean/ordinal; one commit per tick; NO curves/weights; NO RNG in control.
# The system has two radial "levels" (A inner band, B outer band). A deterministic schedule
# allows level transitions at integer periods (P_AtoB, P_BtoA). When a transition fires, the
# level flips A↔B. We record a binary time series T[t]=1 when a transition occurs, else 0.
# DIAGNOSTICS: compute a simple periodogram (DFT evaluated at integer periods) and check that the ON panel
# exhibits strong lines at the predicted periods P_AtoB and P_BtoA (frequencies 1/P_AtoB, 1/P_BtoA), while the REF panel does not.

import argparse, csv, hashlib, json, math, os, sys
from datetime import datetime, timezone
from typing import Dict, List, Tuple

# ---------- utils ----------
def utc_ts() -> str:
    """Current UTC time as a filesystem-safe stamp, e.g. ``2024-01-02T03-04-05Z``."""
    now_utc = datetime.now(timezone.utc)
    return f"{now_utc:%Y-%m-%dT%H-%M-%SZ}"

def ensure_dirs(root:str, subs:List[str]) -> None:
    """Create every ``root/<sub>`` directory (parents included); already-existing ones are fine."""
    for sub in subs:
        target = os.path.join(root, sub)
        os.makedirs(target, exist_ok=True)

def wtxt(path:str, txt:str) -> None:
    """Write ``txt`` to ``path`` as UTF-8 text, replacing any previous contents."""
    with open(path, mode="w", encoding="utf-8") as handle:
        handle.write(txt)

def jdump(path:str, obj:dict) -> None:
    """Serialize ``obj`` to ``path`` as UTF-8 JSON with sorted keys and 2-space indentation."""
    with open(path, mode="w", encoding="utf-8") as handle:
        json.dump(obj, handle, sort_keys=True, indent=2)

def sha256_file(path:str) -> str:
    """Return the hex SHA-256 digest of the file at ``path``, streamed in 1 MiB chunks."""
    digest = hashlib.sha256()
    with open(path, "rb") as stream:
        while block := stream.read(1 << 20):
            digest.update(block)
    return digest.hexdigest()

def isqrt(n:int) -> int:
    """Integer square root: the largest ``r`` with ``r*r <= n`` (ValueError for negative ``n``)."""
    root = math.isqrt(n)  # already an exact int; no float rounding involved
    return root

def modS(x:int, S:int) -> int:
    """Reduce ``x`` modulo ``S``; for positive ``S`` the result lies in ``[0, S)``.

    The negative-remainder correction below is defensive: Python's ``%`` already returns a
    non-negative remainder for positive ``S``, so it only triggers when ``S`` is negative.
    """
    rem = x % S
    if rem < 0:
        rem += S
    return rem

# ---------- geometry helpers ----------
def build_band_shells(r_min:int, r_max:int) -> Tuple[int,int]:
    """Return ``(mid_shell, width)`` for the inclusive radial band ``[r_min, r_max]``.

    Both bounds are truncated with ``int()``; the midpoint uses floor division and the width
    counts both endpoints. Diagnostic only.
    """
    lo, hi = int(r_min), int(r_max)
    mid = (lo + hi) // 2
    width = hi - lo + 1
    return (mid, width)

# ---------- present-act control (one commit per tick) ----------
def simulate_panel(panel:str, H:int, S:int,
                   s0_A:int, step_A:int, s0_B:int, step_B:int,
                   P_AtoB:int, P_BtoA:int,
                   shell_A:int, shell_B:int) -> Dict[str, object]:
    """
    Deterministic two-level controller, one tick at a time for ``H`` ticks.

    State: the current level ("A" at start) plus one sector index per level, each advanced by
    its step modulo ``S`` every tick.

    Transition rule (only when ``panel == "ON"``; any other panel never flips):
      - in level A, flip to B at ticks t where P_AtoB > 0 and t % P_AtoB == 0;
      - in level B, flip to A at ticks t where P_BtoA > 0 and t % P_BtoA == 0.
    Note t = 0 satisfies the A→B condition, so an ON panel with P_AtoB > 0 starts in B.

    Returns {"T": T, "L": L} where T[t] = 1 if a flip fired at tick t (else 0) and L[t] is
    the level in force after the transition check at t. The sector indices and shell_A/shell_B
    do not influence the returned series.
    """
    def wrap(value):
        # same reduction as the module's modS helper
        rem = value % S
        return rem if rem >= 0 else rem + S

    flips_enabled = panel == "ON"
    level = "A"
    sector_A = int(s0_A)
    sector_B = int(s0_B)
    fired: List[int] = []
    levels: List[str] = []

    for t in range(H):
        flipped = 0
        if flips_enabled:
            if level == "A":
                if P_AtoB > 0 and t % P_AtoB == 0:
                    level, flipped = "B", 1
            elif P_BtoA > 0 and t % P_BtoA == 0:
                level, flipped = "A", 1
        fired.append(flipped)
        levels.append(level)

        # sectors advance every tick regardless of level (parity-safe steps)
        sector_A = wrap(sector_A + step_A)
        sector_B = wrap(sector_B + step_B)

    return {"T": fired, "L": levels}

# ---------- diagnostics: periodogram & fit ----------
def periodogram(T: List[int], H:int, max_period:int) -> List[Tuple[int,float]]:
    """
    Single-bin DFT magnitude of T[0..H-1] at each integer period P = 2..max_period.

    For period P the angular frequency is 2π/P; the magnitude |Σ T[t]·e^{iωt}| is divided
    by H. Returns [(P, amplitude), ...] in increasing P (empty if max_period < 2).
    """
    spectrum: List[Tuple[int,float]] = []
    for period in range(2, max_period + 1):
        omega = 2.0 * math.pi / period
        re_part = 0.0
        im_part = 0.0
        for t in range(H):
            phase = omega * t
            x = T[t]
            re_part += x * math.cos(phase)
            im_part += x * math.sin(phase)
        magnitude = math.sqrt(re_part * re_part + im_part * im_part) / H
        spectrum.append((period, magnitude))
    return spectrum

def pick_value(amps: List[Tuple[int,float]], P:int) -> float:
    """Amplitude of the first ``(period, amplitude)`` entry whose period equals ``P``; NaN if absent."""
    matches = (amp for period, amp in amps if period == P)
    return next(matches, float("nan"))

def snr_at(amps: List[Tuple[int,float]], P:int) -> float:
    """
    Signal-to-noise ratio of the periodogram line at period ``P``.

    SNR = amplitude at P / upper median of the amplitudes at all other periods. The "median"
    is the element at index n//2 of the sorted others, i.e. the upper of the two middle
    values when n is even.

    Returns:
      nan  if ``amps`` has no entry for P (including an empty ``amps``);
      inf  if there are no other periods, or the background median is 0 while the peak is > 0;
      0.0  if both the peak and the background median are 0 (no line at all).
    """
    # first matching entry, same lookup rule as pick_value
    peak = next((a for p, a in amps if p == P), float("nan"))
    if math.isnan(peak):
        return float("nan")
    others = sorted(a for p, a in amps if p != P)
    if not others:
        return float("inf")
    med = others[len(others) // 2]
    if med > 0:
        return peak / med
    # Zero background: a non-zero peak is an unbounded line, but a zero peak means there is
    # no line. Returning inf here would make an all-zero REF series fail "snr <= snr_max".
    return float("inf") if peak > 0 else 0.0

def r2_two_line_fit(T: List[int], H:int, P1:int, P2:int) -> float:
    """
    R^2 of T ~ a0 + a1*cos(2πt/P1)+b1*sin(...) + a2*cos(2πt/P2)+b2*sin(...).

    Coefficients come from projecting the mean-removed series onto each sinusoid and dividing
    by H/2, which is the cos²/sin² norm over an integer number of cycles. The fit is therefore
    approximate when H is not a multiple of a period, and for P = 2, where cos² sums to H.

    Args:
      T: the series; only T[0..H-1] is used.
      H: number of samples.
      P1, P2: line periods in ticks (each must be >= 2). If P1 == P2 the line is fitted once.

    Returns:
      R^2 (at most 1). Returns nan if a period is < 2 or H <= 0. Returns 0.0 for a constant
      series, which has no variance and therefore no line structure to explain.
    """
    if P1 < 2 or P2 < 2 or H <= 0:
        return float("nan")
    mu = sum(T[t] for t in range(H)) / H
    ss_tot = sum((T[t]-mu)*(T[t]-mu) for t in range(H))
    if ss_tot <= 0:
        # Constant series: previously reported 1.0 (a "perfect fit"). That made an all-zero
        # REF series fail the "r2 <= r2_max" no-lines acceptance check.
        return 0.0

    def proj(P):
        # (cos, sin) projections of the mean-removed series at period P
        w = 2.0*math.pi / P
        c = sum((T[t]-mu)*math.cos(w*t) for t in range(H))
        s = sum((T[t]-mu)*math.sin(w*t) for t in range(H))
        return c, s

    # Fit each distinct period once; with P1 == P2 the same sinusoid was previously added twice.
    periods = [P1] if P1 == P2 else [P1, P2]
    coeffs = [(P, *proj(P)) for P in periods]
    half = H / 2.0  # orthonormality factor ~ H/2 for cos/sin over integer cycles

    ss_res = 0.0
    for t in range(H):
        r = mu
        for P, c, s in coeffs:
            ang = 2.0*math.pi*t/P
            r += (c*math.cos(ang) + s*math.sin(ang)) / half
        ss_res += (T[t]-r)*(T[t]-r)
    return 1.0 - ss_res/ss_tot

# ---------- run both panels ----------
def run_h22(M: dict) -> dict:
    """
    Run the ON (transitions enabled) and REF (transitions disabled) panels described by the
    manifest ``M``, compute periodogram line metrics, and evaluate acceptance.

    Manifest keys read:
      H, sectors.S, bands[0] / bands[1] (.r_min, .r_max; level A / level B),
      control.{sA0, stepA, sB0, stepB, P_A_to_B, P_B_to_A},
      spectrum.freq_scan_max_period,
      acceptance.on.{peak_min, snr_min, r2_min}, acceptance.off.{peak_max, snr_max, r2_max}.

    Returns a dict with:
      "on" / "ref": amplitudes A1/A2, SNRs S1/S2 and two-line R2 at the predicted periods;
      "amps_on" / "amps_ref": full periodograms; "P1" (A→B) / "P2" (B→A); "pass".
    NaN metrics (e.g. a predicted period outside the scanned range) compare False and thus
    fail acceptance.
    """
    H       = int(M["H"])
    S       = int(M["sectors"]["S"])
    bands   = M["bands"]
    ctrl    = M["control"]
    spec    = M["spectrum"]
    acc     = M["acceptance"]

    # representative shells (used only to define bands; control commits one act per tick regardless)
    shell_A,_ = build_band_shells(bands[0]["r_min"], bands[0]["r_max"])
    shell_B,_ = build_band_shells(bands[1]["r_min"], bands[1]["r_max"])

    P1 = int(ctrl["P_A_to_B"]); P2 = int(ctrl["P_B_to_A"])

    # ON panel (transitions active)
    on = simulate_panel(
        panel="ON", H=H, S=S,
        s0_A=int(ctrl["sA0"]), step_A=int(ctrl["stepA"]),
        s0_B=int(ctrl["sB0"]), step_B=int(ctrl["stepB"]),
        P_AtoB=P1, P_BtoA=P2,
        shell_A=shell_A, shell_B=shell_B
    )
    # REF panel (no transitions)
    ref = simulate_panel(
        panel="REF", H=H, S=S,
        s0_A=int(ctrl["sA0"]), step_A=int(ctrl["stepA"]),
        s0_B=int(ctrl["sB0"]), step_B=int(ctrl["stepB"]),
        P_AtoB=0, P_BtoA=0,
        shell_A=shell_A, shell_B=shell_B
    )

    # periodogram
    maxP = int(spec["freq_scan_max_period"])
    amps_on  = periodogram(on["T"],  H, maxP)
    amps_ref = periodogram(ref["T"], H, maxP)

    # line metrics at the predicted periods, each computed exactly once per panel
    A1_on = pick_value(amps_on,  P1); A2_on = pick_value(amps_on,  P2)
    S1_on = snr_at(amps_on,  P1);     S2_on = snr_at(amps_on,  P2)
    R2_on = r2_two_line_fit(on["T"], H, P1, P2)
    A1_rf = pick_value(amps_ref, P1); A2_rf = pick_value(amps_ref, P2)
    S1_rf = snr_at(amps_ref, P1);     S2_rf = snr_at(amps_ref, P2)
    R2_rf = r2_two_line_fit(ref["T"], H, P1, P2)

    # acceptance thresholds
    peak_min = float(acc["on"]["peak_min"])
    snr_min  = float(acc["on"]["snr_min"])
    r2_min   = float(acc["on"]["r2_min"])
    peak_max = float(acc["off"]["peak_max"])
    snr_max  = float(acc["off"]["snr_max"])
    r2_max   = float(acc["off"]["r2_max"])

    print(f"[ACC] on: peak_min={peak_min} snr_min={snr_min} r2_min={r2_min}")
    print(f"[ACC] off: peak_max={peak_max} snr_max={snr_max} r2_max={r2_max}")

    # ON must show both lines strongly
    on_ok = ((A1_on >= peak_min) and (A2_on >= peak_min)
             and (S1_on >= snr_min) and (S2_on >= snr_min)
             and (R2_on >= r2_min))

    # REF must show no lines; the same sub-checks drive both the verdict and the debug print
    off_checks = [
        (f"A1_rf={A1_rf} <= {peak_max}", A1_rf <= peak_max),
        (f"A2_rf={A2_rf} <= {peak_max}", A2_rf <= peak_max),
        (f"R2_ref={R2_rf} <= {r2_max}",  R2_rf <= r2_max),
        (f"SNR1={S1_rf} <= {snr_max}",   S1_rf <= snr_max),
        (f"SNR2={S2_rf} <= {snr_max}",   S2_rf <= snr_max),
    ]
    for label, ok in off_checks:
        print(f"[OFF] {label} -> {ok}")
    off_ok = all(ok for _, ok in off_checks)

    print(f"[CHECK] on_ok={on_ok} off_ok={off_ok}")

    passed = bool(on_ok and off_ok)

    return {
        "on": {"A1": A1_on, "A2": A2_on, "S1": S1_on, "S2": S2_on, "R2": R2_on},
        "ref": {"A1": A1_rf, "A2": A2_rf, "S1": S1_rf, "S2": S2_rf, "R2": R2_rf},
        "amps_on": amps_on, "amps_ref": amps_ref,
        "P1": P1, "P2": P2, "pass": passed
    }

def main():
    """
    CLI entry point: ``--manifest <json> --outdir <dir>``.

    Creates the output tree, copies the manifest to config/manifest_h22.json, writes an
    environment log, runs both panels, then writes the ON/REF periodogram CSVs, the audit
    JSON, and a one-line result summary (also printed to stdout).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--manifest", required=True)
    parser.add_argument("--outdir", required=True)
    opts = parser.parse_args()

    out_root = os.path.abspath(opts.outdir)
    ensure_dirs(out_root, ["config","outputs/metrics","outputs/audits","outputs/run_info","logs"])

    with open(opts.manifest, "r", encoding="utf-8") as fh:
        manifest = json.load(fh)
    jdump(os.path.join(out_root, "config", "manifest_h22.json"), manifest)

    # env
    env_lines = [
        f"utc={utc_ts()}",
        f"os={os.name}",
        f"cwd={os.getcwd()}",
        f"python={sys.version.split()[0]}",
    ]
    wtxt(os.path.join(out_root, "logs", "env.txt"), "\n".join(env_lines))

    # run
    audit = run_h22(manifest)

    # write metrics: one CSV per panel, identical layout
    csv_targets = (
        ("h22_periodogram_on.csv", "amps_on"),
        ("h22_periodogram_ref.csv", "amps_ref"),
    )
    for csv_name, amps_key in csv_targets:
        csv_path = os.path.join(out_root, "outputs/metrics", csv_name)
        with open(csv_path, "w", newline="", encoding="utf-8") as fh:
            writer = csv.writer(fh)
            writer.writerow(["period","amplitude"])
            writer.writerows([P, f"{A:.9f}"] for P, A in audit[amps_key])

    # write audit
    jdump(os.path.join(out_root, "outputs", "audits", "h22_audit.json"), audit)

    # result line
    on_m, ref_m = audit["on"], audit["ref"]
    result = (f"H22 PASS={audit['pass']} A1_on={on_m['A1']:.3f} A2_on={on_m['A2']:.3f} "
              f"S1={on_m['S1']:.2f} S2={on_m['S2']:.2f} R2={on_m['R2']:.3f} "
              f"A1_ref={ref_m['A1']:.3f} A2_ref={ref_m['A2']:.3f}")

    wtxt(os.path.join(out_root, "outputs", "run_info", "result_line.txt"), result)
    print(result)

if __name__ == "__main__":
    # Run the CLI only when executed as a script, not when imported as a module.
    main()